/*
 * Copyright 2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.alibaba.cloud.ai.node;

import com.alibaba.cloud.ai.connector.accessor.Accessor;
import com.alibaba.cloud.ai.connector.bo.DbQueryParameter;
import com.alibaba.cloud.ai.connector.bo.ResultSetBO;
import com.alibaba.cloud.ai.connector.config.DbConfig;
import com.alibaba.cloud.ai.constant.Constant;

import com.alibaba.cloud.ai.enums.StreamResponseType;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.model.execution.ExecutionStep;
import com.alibaba.cloud.ai.service.DatasourceService;
import com.alibaba.cloud.ai.entity.AgentDatasource;
import com.alibaba.cloud.ai.entity.Datasource;
import com.alibaba.cloud.ai.util.ChatResponseUtil;
import com.alibaba.cloud.ai.util.StateUtils;
import com.alibaba.cloud.ai.util.StepResultUtils;
import com.alibaba.cloud.ai.util.StreamingChatGeneratorUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.alibaba.cloud.ai.constant.Constant.SQL_EXECUTE_NODE_EXCEPTION_OUTPUT;
import static com.alibaba.cloud.ai.constant.Constant.SQL_EXECUTE_NODE_OUTPUT;

/**
 * SQL execution node that executes SQL queries against the database.
 *
 * This node is responsible for:
 * <ul>
 * <li>Executing the SQL query attached to the current execution step (generated by
 * previous nodes)</li>
 * <li>Handling query results and errors</li>
 * <li>Providing streaming status feedback to users during execution</li>
 * <li>Accumulating per-step results in the overall state</li>
 * </ul>
 *
 * @author zhangshenghang
 */
public class SqlExecuteNode extends AbstractPlanBasedNode {

	private static final Logger logger = LoggerFactory.getLogger(SqlExecuteNode.class);

	/** Accessor used to run the SQL against the resolved data source. */
	private final Accessor dbAccessor;

	/** Service used to look up the data sources bound to an agent. */
	private final DatasourceService datasourceService;

	public SqlExecuteNode(Accessor dbAccessor, DatasourceService datasourceService) {
		super();
		this.dbAccessor = dbAccessor;
		this.datasourceService = datasourceService;
	}

	/**
	 * Executes the SQL query of the current execution step against the agent's active
	 * data source.
	 * @param state the overall graph state (current step, agent ID, previous results)
	 * @return a single-entry map holding the streaming generator; keyed by
	 * {@code SQL_EXECUTE_NODE_OUTPUT} on success or
	 * {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT} when the query fails
	 * @throws Exception propagated from step resolution in the base class
	 */
	@Override
	public Map<String, Object> apply(OverAllState state) throws Exception {
		logNodeEntry();

		ExecutionStep executionStep = getCurrentExecutionStep(state);
		Integer currentStep = getCurrentStepNumber(state);

		ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
		String sqlQuery = toolParameters.getSqlQuery();

		logger.info("Executing SQL query: {}", sqlQuery);
		logger.info("Step description: {}", toolParameters.getDescription());

		// Dynamically resolve the data source configuration for the agent
		DbConfig dbConfig = getAgentDbConfig(state);

		return executeSqlQuery(state, currentStep, sqlQuery, dbConfig);
	}

	/**
	 * Dynamically get the data source configuration for an agent.
	 * @param state The state object containing the agent ID
	 * @return The database configuration corresponding to the agent's first active data
	 * source
	 * @throws RuntimeException If the agent ID is missing or malformed, or the agent has
	 * no enabled data source configured
	 */
	private DbConfig getAgentDbConfig(OverAllState state) {
		try {
			// Get the agent ID from the state
			String agentIdStr = StateUtils.getStringValue(state, Constant.AGENT_ID);
			if (agentIdStr == null || agentIdStr.trim().isEmpty()) {
				throw new RuntimeException("未找到智能体ID，无法获取数据源配置");
			}

			// Trim before parsing: the blank check above already tolerates surrounding
			// whitespace, so parsing must as well.
			Integer agentId = Integer.valueOf(agentIdStr.trim());
			logger.info("Getting datasource config for agent: {}", agentId);

			// Get the data sources bound to the agent
			List<AgentDatasource> agentDatasources = datasourceService.getAgentDatasources(agentId);
			if (agentDatasources == null || agentDatasources.isEmpty()) {
				// TODO: temporary workaround for an agent-ID mismatch seen while debugging;
				// retries the lookup with the ID offset by 999999. Remove once IDs agree.
				agentDatasources = datasourceService.getAgentDatasources(agentId - 999999);
			}
			if (agentDatasources == null) {
				// Treat a missing list like an empty one so the error below is reported
				agentDatasources = List.of();
			}

			AgentDatasource activeDatasource = agentDatasources.stream()
				.filter(ad -> ad.getIsActive() == 1)
				.findFirst()
				.orElseThrow(() -> new RuntimeException("智能体 " + agentId + " 未配置启用的数据源"));

			// Convert to DbConfig
			DbConfig dbConfig = createDbConfigFromDatasource(activeDatasource.getDatasource());
			logger.info("Successfully created DbConfig for agent {}: url={}, schema={}, type={}", agentId,
					dbConfig.getUrl(), dbConfig.getSchema(), dbConfig.getDialectType());

			return dbConfig;
		}
		catch (Exception e) {
			logger.error("Failed to get agent datasource config", e);
			throw new RuntimeException("获取智能体数据源配置失败: " + e.getMessage(), e);
		}
	}

	/**
	 * Create database configuration from data source entity. Only MySQL and PostgreSQL
	 * (matched case-insensitively) are supported; both use a JDBC connection.
	 * @param datasource The data source entity
	 * @return The database configuration object
	 * @throws RuntimeException If the data source is missing or its type is unsupported
	 */
	private DbConfig createDbConfigFromDatasource(Datasource datasource) {
		if (datasource == null) {
			throw new RuntimeException("数据源配置不存在");
		}

		DbConfig dbConfig = new DbConfig();

		// Set basic connection information
		dbConfig.setUrl(datasource.getConnectionUrl());
		dbConfig.setUsername(datasource.getUsername());
		dbConfig.setPassword(datasource.getPassword());

		// Set database type
		if ("mysql".equalsIgnoreCase(datasource.getType())) {
			dbConfig.setConnectionType("jdbc");
			dbConfig.setDialectType("mysql");
		}
		else if ("postgresql".equalsIgnoreCase(datasource.getType())) {
			dbConfig.setConnectionType("jdbc");
			dbConfig.setDialectType("postgresql");
		}
		else {
			throw new RuntimeException("不支持的数据库类型: " + datasource.getType());
		}

		// Set Schema to the database name of the data source
		dbConfig.setSchema(datasource.getDatabaseName());

		return dbConfig;
	}

	/**
	 * Executes the SQL query against the database and handles the results.
	 *
	 * This method follows the business-logic-first pattern:
	 * <ol>
	 * <li>Execute the actual SQL query immediately</li>
	 * <li>Process and store the results</li>
	 * <li>Create streaming output for user experience only</li>
	 * </ol>
	 * @param state The overall state containing execution context
	 * @param currentStep The current step number in the execution plan
	 * @param sqlQuery The SQL query to execute
	 * @param dbConfig The database configuration to use for execution
	 * @return Map containing the generator for streaming output, keyed by
	 * {@code SQL_EXECUTE_NODE_OUTPUT} on success or
	 * {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT} on failure
	 */
	@SuppressWarnings("unchecked")
	private Map<String, Object> executeSqlQuery(OverAllState state, Integer currentStep, String sqlQuery,
			DbConfig dbConfig) {
		// Execute business logic first - actual SQL execution
		DbQueryParameter dbQueryParameter = new DbQueryParameter();
		dbQueryParameter.setSql(sqlQuery);

		try {
			// Execute SQL query and get results immediately
			ResultSetBO resultSetBO = dbAccessor.executeSqlAndReturnObject(dbConfig, dbQueryParameter);
			String jsonStr = resultSetBO.toJsonStr();

			// Update step results with the query output
			Map<String, String> existingResults = StateUtils.getObjectValue(state, SQL_EXECUTE_NODE_OUTPUT, Map.class,
					new HashMap<>());
			Map<String, String> updatedResults = StepResultUtils.addStepResult(existingResults, currentStep, jsonStr);

			// Row data may be absent; Map.of rejects null values, so substitute an empty
			// list rather than failing a query that actually succeeded.
			var rows = resultSetBO.getData();
			logger.info("SQL execution successful, result count: {}", rows != null ? rows.size() : 0);
			Object sqlResultList = rows != null ? rows : List.of();

			// Prepare the final result object
			// Store List of SQL query results for use by code execution node
			Map<String, Object> result = Map.of(SQL_EXECUTE_NODE_OUTPUT, updatedResults,
					SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, "", Constant.SQL_RESULT_LIST_MEMORY, sqlResultList);

			// Create display flux for user experience only
			Flux<ChatResponse> displayFlux = Flux.create(emitter -> {
				emitter.next(ChatResponseUtil.createStatusResponse("开始执行SQL..."));
				emitter.next(ChatResponseUtil.createStatusResponse("执行SQL查询"));
				emitter.next(ChatResponseUtil.createStatusResponse("```" + sqlQuery + "```"));
				emitter.next(ChatResponseUtil.createStatusResponse("执行SQL完成"));
				emitter.complete();
			});

			// Create generator using utility class, returning pre-computed business logic
			// result
			var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
					v -> result, displayFlux, StreamResponseType.EXECUTE_SQL);

			return Map.of(SQL_EXECUTE_NODE_OUTPUT, generator);
		}
		catch (Exception e) {
			// Many exceptions carry no message and Map.of rejects null values; fall back
			// to the exception type so the failure is still reported instead of an NPE.
			String errorMessage = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
			logger.error("SQL execution failed - SQL: [{}] ", sqlQuery, e);

			// Prepare error result
			Map<String, Object> errorResult = Map.of(SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, errorMessage);

			// Create error display flux
			Flux<ChatResponse> errorDisplayFlux = Flux.create(emitter -> {
				emitter.next(ChatResponseUtil.createStatusResponse("开始执行SQL..."));
				emitter.next(ChatResponseUtil.createStatusResponse("执行SQL查询"));
				emitter.next(ChatResponseUtil.createStatusResponse("SQL执行失败: " + errorMessage));
				emitter.complete();
			});

			// Create error generator using utility class
			var generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state,
					v -> errorResult, errorDisplayFlux, StreamResponseType.EXECUTE_SQL);

			return Map.of(SQL_EXECUTE_NODE_EXCEPTION_OUTPUT, generator);
		}
	}

}
